
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import hashlib
import hmac
from hashlib import pbkdf2_hmac


@cython.final
cdef class RFC5802Authentication:
    """Client side of the RFC5802 SCRAM authentication mechanism.

    This class computes the client proof for RFC5802 SCRAM password
    authentication. The protocol is specified in:

    https://tools.ietf.org/rfc/rfc5802.txt

    The computation performed by :meth:`authenticate` is:

    - A salted password is derived with PBKDF2 (HMAC-SHA1) from the password,
      the hex-encoded salt and the iteration count
    - Server and client keys are derived from the salted password using
      HMAC-SHA256
    - A stored key is produced by hashing the client key (SHA256 or SM3)
    - The client signature is the HMAC-SHA256 of the stored key and the token
    - The client proof is the XOR of the client signature and the client key

    Supported stored-key hash methods are SHA256 and SM3. SM3 is taken from
    the OpenSSL backend of :mod:`hashlib` when it is available; otherwise a
    SHA256 substitute is used (which cannot match a genuine SM3 stored key).

    The password itself is never transmitted, only proofs derived from it.
    """

    DEFAULT_ITERATIONS = 10000
    DEFAULT_KEY_LENGTH = 32
    SUPPORTED_METHODS = ["sha256", "sm3"]
    # Retained for backward compatibility; decoding now uses bytes.fromhex().
    HEX_CHARS = "0123456789ABCDEF"
    HEX_LOOKUP = b'0123456789abcdef'

    def __cinit__(self):
        pass

    cdef hex_string_to_bytes(self, str hex_string):
        """Convert a hexadecimal string to bytes.

        Decoding is case-insensitive; e.g. "48656c6c6f" decodes to b'Hello'.

        Args:
            hex_string (str): Hexadecimal string to convert. An empty string
                or None yields b''.

        Returns:
            bytes: Decoded bytes object

        Raises:
            ValueError: If the string has odd length or contains characters
                that are not hexadecimal digits.
        """
        if not hex_string:
            return b''
        # Unlike a manual pairwise loop, bytes.fromhex() rejects odd-length
        # input instead of silently discarding the trailing nibble.
        return bytes.fromhex(hex_string)

    cdef _generate_k_from_pbkdf2(self, str password, str random64code,
                                 int server_iteration):
        """Derive the salted password (K) with PBKDF2-HMAC-SHA1.

        SHA1 (not SHA256) is used here, as required by the peer implementation
        this code interoperates with.

        Args:
            password (str): Password string, encoded as UTF-8
            random64code (str): Hex-encoded salt (64 hex characters)
            server_iteration (int): Number of PBKDF2 iterations (must be > 0)

        Returns:
            bytes: Derived key of DEFAULT_KEY_LENGTH (32) bytes
        """
        cdef bytes salt = self.hex_string_to_bytes(random64code)
        return pbkdf2_hmac(
            'sha1',
            password.encode('utf-8'),
            salt,
            server_iteration,
            self.DEFAULT_KEY_LENGTH
        )

    cdef _bytes_to_hex_string(self, bytes src):
        """Convert bytes to a lowercase hexadecimal string.

        Args:
            src (bytes): Source bytes to convert

        Returns:
            str: Lowercase hex representation, two characters per byte
        """
        return src.hex()

    cdef _get_key_from_hmac(self, bytes key, bytes data):
        """Compute HMAC-SHA256 of data under key.

        Args:
            key (bytes): HMAC key
            data (bytes): Data to be authenticated

        Returns:
            bytes: 32-byte HMAC digest
        """
        return hmac.new(key, data, hashlib.sha256).digest()

    cdef _get_sha256(self, bytes message):
        """Calculate the SHA256 hash of message.

        Args:
            message (bytes): Message to hash

        Returns:
            bytes: SHA256 digest
        """
        return hashlib.sha256(message).digest()

    cdef _get_sm3(self, bytes message):
        """Calculate the SM3 hash of message.

        Uses the SM3 implementation exposed by hashlib's OpenSSL backend. If
        the linked OpenSSL does not provide SM3, falls back to SHA256 (the
        historical behaviour of this method).

        Args:
            message (bytes): Message to hash

        Returns:
            bytes: SM3 digest, or a SHA256 digest when SM3 is unavailable
        """
        cdef object hash_obj
        try:
            hash_obj = hashlib.new('sm3')
        except ValueError:
            # SM3 unsupported by this OpenSSL build. A SHA256 digest will not
            # match a genuine SM3 stored key, so authentication is expected
            # to fail in this configuration.
            hash_obj = hashlib.sha256()
        hash_obj.update(message)
        return hash_obj.digest()

    cdef _xor_between_password(self, bytes password1, bytes password2, int length):
        """XOR the first `length` bytes of two byte strings.

        Corresponds to Go's XorBetweenPassword function.

        Args:
            password1 (bytes): First operand
            password2 (bytes): Second operand
            length (int): Number of bytes to XOR; both operands must be at
                least this long (IndexError otherwise)

        Returns:
            bytes: XOR result of `length` bytes
        """
        cdef bytearray result = bytearray(length)
        cdef int i

        for i in range(length):
            result[i] = password1[i] ^ password2[i]
        return bytes(result)

    cdef _bytes_to_hex(self, bytes source_bytes, bytearray result_array=None,
                       int start_pos=0, int length=-1):
        """Encode bytes as lowercase ASCII hex digits.

        Args:
            source_bytes (bytes): Source bytes to convert
            result_array (bytearray, optional): Target buffer. If None, a new
                bytes object is returned instead.
            start_pos (int, optional): Offset in result_array at which to
                start writing. Default is 0.
            length (int, optional): Number of source bytes to convert (capped
                at len(source_bytes)). If -1, converts all bytes.

        Returns:
            bytes or bytearray: New hex bytes when result_array is None,
            otherwise result_array itself after being filled in place.
        """
        cdef int pos, i, c
        cdef bytes lookup = self.HEX_LOOKUP

        if result_array is None:
            # Same lowercase two-digits-per-byte encoding as the table below.
            return source_bytes.hex().encode('ascii')

        if length == -1 or length > len(source_bytes):
            length = len(source_bytes)

        pos = start_pos
        for i in range(length):
            # Mask defensively so the value is always treated as unsigned.
            c = source_bytes[i] & 0xFF
            result_array[pos] = lookup[c >> 4]
            result_array[pos + 1] = lookup[c & 0xF]
            pos += 2
        return result_array

    cpdef authenticate(self, str password, str random64code, str token,
                       str server_signature="", int server_iteration=0,
                       str method="sha256"):
        """Compute the RFC5802 SCRAM client proof.

        Args:
            password (str): User password
            random64code (str): Hex-encoded salt (64 hex characters)
            token (str): Authentication token in hex format
            server_signature (str, optional): Expected server signature as a
                lowercase hex string. When non-empty it is verified before the
                proof is computed. Default is empty string (no verification).
            server_iteration (int, optional): PBKDF2 iteration count. Values
                <= 0 mean "not supplied" and select DEFAULT_ITERATIONS (10000).
            method (str, optional): Stored-key hash algorithm, 'sha256' or
                'sm3' (case-insensitive). Default is 'sha256'.

        Returns:
            bytes: Client proof as lowercase ASCII hex, or b"" when the
            supplied server_signature does not match.

        Raises:
            ValueError: If method is unsupported, or if any step of the
                computation fails (e.g. malformed hex input).

        Example (illustrative, not a doctest)::

            auth = RFC5802Authentication()
            proof = auth.authenticate("password", "ab" * 32, "cd" * 16)
            # proof is 64 bytes of lowercase hex (32-byte proof)
        """
        cdef bytes k, server_key, client_key, stored_key, token_byte
        cdef bytes expected_server_signature, client_signature, client_proof
        cdef str normalized_method

        # Only fall back to the default when the caller did not supply a
        # usable count; a server-provided count must be honoured as-is.
        if server_iteration <= 0:
            server_iteration = self.DEFAULT_ITERATIONS

        normalized_method = method.lower()
        if normalized_method not in self.SUPPORTED_METHODS:
            raise ValueError(f"Unsupported hash method: {method}")

        try:
            # Step 1: Generate K (SaltedPassword)
            k = self._generate_k_from_pbkdf2(password, random64code, server_iteration)

            # Step 2: Generate ServerKey and ClientKey.
            # "Sever Key" (sic) is intentional: the HMAC label must match the
            # peer's derivation byte-for-byte.
            server_key = self._get_key_from_hmac(k, b"Sever Key")
            client_key = self._get_key_from_hmac(k, b"Client Key")

            # Step 3: Generate StoredKey
            if normalized_method == "sm3":
                stored_key = self._get_sm3(client_key)
            else:
                stored_key = self._get_sha256(client_key)

            # Step 4: Convert token to bytes
            token_byte = self.hex_string_to_bytes(token)

            # Steps 5-6: Verify the server's signature when one was supplied.
            # compare_digest avoids leaking the match position via timing.
            expected_server_signature = self._get_key_from_hmac(server_key, token_byte)
            if server_signature and not hmac.compare_digest(
                    server_signature.encode('utf-8'),
                    self._bytes_to_hex_string(expected_server_signature).encode('ascii')):
                return b""

            # Step 7: Calculate ClientSignature
            client_signature = self._get_key_from_hmac(stored_key, token_byte)

            # Step 8: ClientProof = ClientSignature XOR ClientKey
            client_proof = self._xor_between_password(
                client_signature, client_key, len(client_key))

            # Step 9: Hex-encode (corresponds to Java's bytesToHex)
            return self._bytes_to_hex(client_proof)

        except Exception as e:
            raise ValueError(f"RFC5802 algorithm execution failed: {e}") from e
